
Add per-layer hybrid sliding/full attention (Gemma 3 / Gemma 4) to CoreML static LLM export #19251

Open
john-rocky wants to merge 3 commits into pytorch:main from john-rocky:coreml/per-layer-sliding-window-v2

Conversation

@john-rocky

Summary

Stacked on top of #19250 — that PR caps every layer's KV cache at a single
global window. Gemma 3, Gemma 4, and the Llama 4 Scout family instead
interleave sliding and full attention layers:

  • Gemma 4 E2B: [sliding × 4, full × 1] × 7 = 35 layers (P=5)
  • Gemma 3: [sliding × 5, full × 1] × N (P=6)

HuggingFace expresses this as a single integer sliding_window_pattern,
which is what the new --sliding_window_pattern flag mirrors.
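
For concreteness, a minimal, self-contained illustration of the single-integer rule (not the export script's code; the helper name here is made up):

```python
# Illustration only: how one integer P ("sliding_window_pattern") classifies layers.
# HF rule: layer i (0-based) uses full attention iff (i + 1) % P == 0.
def layer_types(n_layers: int, pattern: int) -> list[str]:
    return ["full" if (i + 1) % pattern == 0 else "sliding" for i in range(n_layers)]

assert layer_types(35, 5).count("full") == 7                  # Gemma 4 E2B: [sliding x 4, full] x 7
assert layer_types(12, 6)[:6] == ["sliding"] * 5 + ["full"]   # Gemma 3: every 6th layer is full
```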

What changed (this PR's commit only)

  • _resolve_per_layer_cache_lens(...) produces a per-layer cache_lens
    list using the HF rule (layer i is full iff (i + 1) % P == 0).
    StaticAttentionIOManager already accepts per-layer cache_lens, so the
    attention mask dict (one mask per unique cache_len) and per-layer KV
    cache shapes fall out for free. The forward pass already keys
    mask = masks[cache_len] per layer, so it picks the right mask without
    any model code change (a sketch of the resolver follows this list).
  • _get_metadata now reads each cache's cache_len from the example
    tensor's sequence dimension instead of taking a single scalar, so the
    C++ runner metadata reports each layer's actual length under hybrid
    attention.
  • Both the single-method and multifunction export paths route through
    the per-layer resolver.
  • The uniform-sliding behavior introduced by #19250 (Add --sliding_window flag
    to CoreML static LLM export) is preserved when --sliding_window_pattern is not set.
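
A hedged sketch of the resolver described in the first bullet; the real function in export_static_llm_coreml.py may use a different signature and argument names:

```python
# Sketch only (assumed signature): per-layer cache lengths derived from the two flags.
def resolve_per_layer_cache_lens(
    n_layers: int,
    max_context_len: int,
    input_len: int,
    sliding_window: int | None,
    sliding_window_pattern: int | None,
) -> list[int]:
    full_len = max_context_len - input_len
    if sliding_window is None:
        return [full_len] * n_layers                       # no flags: uniform full cache
    if sliding_window_pattern is None:
        return [min(sliding_window, full_len)] * n_layers  # uniform sliding (#19250 behavior)
    # HF rule: layer i is full iff (i + 1) % P == 0, otherwise sliding.
    return [
        full_len if (i + 1) % sliding_window_pattern == 0 else min(sliding_window, full_len)
        for i in range(n_layers)
    ]
```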

Why it matters

For Gemma 4 E2B with max_context_len=8192 and --sliding_window 4096 --sliding_window_pattern 5:

| Per layer | Cache length | Layers | Total cache (fp16, both K and V) |
| --- | --- | --- | --- |
| Sliding | 4096 | 28 | 28 × 4096 × 2 × n_kv_heads × head_dim × 2 B |
| Full | 8160 | 7 | 7 × 8160 × 2 × n_kv_heads × head_dim × 2 B |

Compare that with naively giving every layer the full 8160-token cache: for the
E2B config (n_kv_heads=1, head_dim=256) this works out to 86 MB hybrid vs.
143 MB uniform-full, and the savings grow proportionally for E4B.
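
A quick sanity check of the split and the relative saving; the input length of 32 is inferred here from 8160 = 8192 − 32 and is an assumption, not something the PR states directly:

```python
# Gemma 4 E2B example from the table above.
n_layers, pattern = 35, 5
sliding_len, full_len = 4096, 8192 - 32

lens = [full_len if (i + 1) % pattern == 0 else sliding_len for i in range(n_layers)]
print(lens.count(sliding_len), lens.count(full_len))  # 28 7

# The KV-cache byte count scales with the summed cache lengths, whatever the
# per-token constant (n_kv_heads, head_dim, dtype, K+V) works out to.
print(sum(lens) / (n_layers * full_len))              # ~0.60, i.e. ~40% smaller than uniform-full
```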

Review order

This PR contains two commits. The first ('Add --sliding_window flag…')
is identical to #19250 — please merge that one first; the diff on this
PR will then collapse to just the per-layer commit. I'm happy to rebase
once #19250 lands.

Test plan

Added 7 unit tests in examples/apple/coreml/llama/test.py:

  • test_per_layer_cache_lens_uniform_when_no_pattern — back-compat with
    the --sliding_window-only path.
  • test_per_layer_cache_lens_uniform_full_when_no_window — no flag at
    all leaves every layer at max_context_len - input_len.
  • test_per_layer_cache_lens_gemma4_e2b_pattern — 35 layers, P=5
    produces 28 sliding + 7 full in the right positions (sketched below).
  • test_per_layer_cache_lens_gemma3_pattern — P=6 produces the
    documented [s, s, s, s, s, f, ...] interleave.
  • test_per_layer_cache_lens_pattern_requires_sliding_window — input
    validation.
  • test_per_layer_cache_lens_rejects_pattern_le_one — input validation
    (P=1 would degenerate to all-full and is almost certainly a typo).
  • test_create_example_inputs_with_per_layer_pattern_yields_two_cache_sizes
    — full path: example inputs really do contain caches of both sizes
    and a mask per cache_len.
$ python -m pytest examples/apple/coreml/llama/test.py -v
============================== 13 passed in 3.87s ==============================

(All 6 tests from #19250 still pass alongside the 7 new ones.)
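
For reference, a hedged sketch of the Gemma 4 E2B case from the list above, reusing the assumed resolver signature sketched under "What changed"; the real test in examples/apple/coreml/llama/test.py may be structured differently:

```python
# Sketch only: 35 layers with P=5 should yield 28 sliding + 7 full caches,
# with the full ones at 0-based layers 4, 9, ..., 34.
def test_per_layer_cache_lens_gemma4_e2b_pattern():
    lens = resolve_per_layer_cache_lens(
        n_layers=35, max_context_len=8192, input_len=32,
        sliding_window=4096, sliding_window_pattern=5,
    )
    full_len = 8192 - 32
    assert len(lens) == 35
    assert [i for i, l in enumerate(lens) if l == full_len] == list(range(4, 35, 5))
    assert all(l == 4096 for i, l in enumerate(lens) if (i + 1) % 5 != 0)
```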

Authored with Claude.

john-rocky added 2 commits May 1, 2026 14:29
Models trained with sliding-window attention (Mistral 7B, Gemma 3,
Gemma 4, Llama 4 Scout, …) only need each layer to attend to the
last `W` tokens, but `export_static_llm_coreml.py` was always
sizing the per-layer KV cache to `max_context_len - input_len`.
That made longer contexts proportionally more expensive in both KV
cache memory and per-token attention compute, even though the model
was trained to ignore everything outside the window.

Add a `--sliding_window` flag that caps the cache at the trained
window.  The downstream pieces — `StaticAttentionMask` invariants
under cache eviction and the `StaticAttentionIOManager`'s per-layer
`cache_lens` plumbing — already support this; the export script
just needed to expose it.  Per-layer mixed sliding/full attention
(Gemma 3/4) is left for a follow-up; this PR uses one window for
every layer.

The cache_len computation is factored into `_resolve_cache_len` so
it is unit-testable, and the README's ANE Optimizations section
documents the new option.

### Memory savings example

For a 32-layer / n_kv_heads=8 / head_dim=128 model exported with
`max_context_len=8192` in fp16, dropping the cache from 8160 to
4096 cuts the per-method KV cache from ~1.07 GB to ~0.54 GB.
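
The arithmetic behind those two figures, as a quick check (fp16 = 2 bytes, K and V both cached):

```python
# 32 layers, n_kv_heads=8, head_dim=128, fp16, K and V.
layers, n_kv_heads, head_dim = 32, 8, 128
bytes_per_token_per_layer = 2 * n_kv_heads * head_dim * 2  # K and V, 2 bytes per element
print(layers * 8160 * bytes_per_token_per_layer / 1e9)     # ~1.07 GB (cache_len 8192 - 32)
print(layers * 4096 * bytes_per_token_per_layer / 1e9)     # ~0.54 GB (4096-token window)
```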
Builds on the prior --sliding_window flag.  Gemma 3, Gemma 4, and the
Llama 4 Scout family interleave sliding and full attention layers
rather than using one global setting: Gemma 4 E2B is '4 sliding + 1
full' repeated 7 times across 35 layers; Gemma 3 is '5 sliding + 1
full' repeated.  HuggingFace expresses this as a single integer
`sliding_window_pattern`, which is what the new
`--sliding_window_pattern` flag mirrors.

Implementation:

- `_resolve_per_layer_cache_lens(...)` produces a per-layer cache_lens
  list using the HF rule (layer i is full iff (i+1) % P == 0); the
  IO manager and the model already accept per-layer cache_lens, so the
  attention mask dict and the per-layer KV cache shapes follow.
- `_get_metadata` now reads each cache's cache_len from the example
  tensor's sequence dimension instead of receiving a single scalar,
  so the C++ runner metadata describes each layer correctly under
  hybrid attention (see the sketch after this list).
- Both single-method and multifunction export paths use the per-layer
  resolver.
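
A hedged sketch of the metadata idea in the second bullet; the helper name, tensor layout, and sequence-dimension index below are assumptions, not the export script's actual code:

```python
import torch

# Sketch only: derive each cache's length from its example tensor instead of
# threading one scalar through, so hybrid layers report their own lengths.
def per_cache_lens(example_cache_tensors, seq_dim: int = -2) -> list[int]:
    return [int(t.shape[seq_dim]) for t in example_cache_tensors]

caches = [torch.zeros(1, 1, 4096, 256), torch.zeros(1, 1, 8160, 256)]
print(per_cache_lens(caches))  # [4096, 8160]
```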

The previous PR's uniform-sliding behavior is preserved when
`--sliding_window_pattern` is not set.

Authored with Claude.
@john-rocky john-rocky requested a review from metascroy as a code owner May 1, 2026 05:39

pytorch-bot Bot commented May 1, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19251

Note: Links to docs will display an error until the docs builds have been completed.

⚠️ 11 Awaiting Approval

As of commit fe97e43 with merge base 94d2881:

AWAITING APPROVAL - The following workflows need approval before CI can run:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label May 1, 2026

github-actions Bot commented May 1, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

john-rocky added 1 commit
Verifies a tiny static-attention transformer accepts the heterogeneous
cache shapes produced by _resolve_per_layer_cache_lens and runs a
forward pass without errors — the strongest signal that the model and
IO Manager really do route the right mask per layer under hybrid
sliding/full attention.
